{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n作者：英俊\\nQQ:2227495940\\n邮箱 同上\\n'"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "作者：英俊\n",
    "QQ:2227495940\n",
    "邮箱 同上\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#导入Pytorch包\n",
    "\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from fastNLP.io.loader import CSVLoader\n",
    "\n",
    "dataset_loader = CSVLoader(headers=('raw_words','target'), sep='\\t')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#看来分隔符很重要啊\n",
    "testset_loader = CSVLoader( headers=['raw_words'],sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 表示将CSV文件中每一行的第一项将填入'raw_words' field，第二项填入'target' field。\n",
    "\n",
    "# 其中项之间由'\\t'分割开来\n",
    "\n",
    "train_path=r'train_shuffle.txt'\n",
    "\n",
    "test_path=r'new_test_handout.txt'\n",
    "\n",
    "dataset = dataset_loader._load(train_path)\n",
    "\n",
    "testset = testset_loader._load(test_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.39\n"
     ]
    }
   ],
   "source": [
    "# 将句子分成单词形式, 详见DataSet.apply()方法\n",
    "\n",
    "import jieba\n",
    "\n",
    "from itertools import chain\n",
    "\n",
    "print(jieba.__version__)\n",
    "# from itertools import chain\n",
    "\n",
    "#     '''\n",
    "\n",
    "#     @params:\n",
    "\n",
    "#         data: 数据的列表，列表中的每个元素为 [文本字符串，0/1标签] 二元组\n",
    "\n",
    "#     @return: 切分词后的文本的列表，列表中的每个元素为切分后的词序列\n",
    "\n",
    "#     '''\n",
    "\n",
    "def get_tokenized(data,words=True):\n",
    "    def tokenizer(text):\n",
    "        return [tok for tok in jieba.cut(text, cut_all=False)]\n",
    "    if words:\n",
    "\n",
    "        #按词语进行编码\n",
    "\n",
    "        return tokenizer(data)\n",
    "\n",
    "    else:\n",
    "\n",
    "        #按字进行编码\n",
    "\n",
    "        return [tokenizer(review) for review in data]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------------------------+--------+\n",
      "| raw_words                           | target |\n",
      "+-------------------------------------+--------+\n",
      "| 那是什么破酒店，你们携程也对外推... | 0      |\n",
      "| 非常一般的酒店，服务台上挂的特色... | 0      |\n",
      "| 入住了传说中的04房间，景色真的很... | 1      |\n",
      "| 餐厅很差,菜的种类水准都不行.        | 0      |\n",
      "|                                     |        |\n",
      "| 酒...                               |        |\n",
      "| 酒店位于福州最热闹的东街口，出来... | 1      |\n",
      "| 上月入住，将近500元一天的房费，...  | 0      |\n",
      "| 老掉牙的五星酒店，我们入住的客房... | 0      |\n",
      "| 酒店位置在很市中心的地方,房间也...  | 1      |\n",
      "| 除了停车不是很方便，其他的都还不... | 1      |\n",
      "| 原先看了别人的评价，做好了不怎么... | 1      |\n",
      "| 宜宾还有很多好的酒店，比如说叙府... | 0      |\n",
      "| 不错,变成常驻京酒店了,就是餐厅的... | 1      |\n",
      "| 最近变成了四星，好像房间装修过了... | 0      |\n",
      "| 一分钟一块钱的网费，贵的惊人。三... | 0      |\n",
      "| 四个字，“糟糕透顶”。我是晚上十...   | 0      |\n",
      "| 我住的是大床房，房间很大，还有一... | 1      |\n",
      "| 几个月之后再次入住，一进门大厅里... | 1      |\n",
      "| 这次和老公入住这家酒店 640元的豪... | 0      |\n",
      "| 忘了说，前台服务较烂，中午在莫泰... | 0      |\n",
      "| 房间价格奇贵，上网费2元每分钟，...  | 0      |\n",
      "| 不错，离大连港很近，步行几分钟就... | 1      |\n",
      "| 前台的服务很差，10月1日的值班经...  | 0      |\n",
      "| 酒店邻近山脚，空气清新，闹市取静... | 1      |\n",
      "| 这是北京能见到的性价比最差的酒店... | 0      |\n",
      "| 房间比中州皇冠之类的大多了，服务... | 1      |\n",
      "| 这家YMCA非常好，还主动帮我们升级... | 1      |\n",
      "| 便宜，物有所值。                    | 1      |\n",
      "|                                     |        |\n",
      "| 早餐：相对比较...                   |        |\n",
      "| 只能说，还可以住住。硬件，环境没... | 0      |\n",
      "| 优点:靠近商业中心,离火车站不远,...  | 0      |\n",
      "| 服务不是很专业，处理速度太慢。不... | 1      |\n",
      "| 酒店的位置不错，附近都靠近购物中... | 1      |\n",
      "| ...                                 | ...    |\n",
      "+-------------------------------------+--------+\n"
     ]
    }
   ],
   "source": [
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+--------+----------------------+----------------------+\n",
      "| raw_words            | target | chars                | words                |\n",
      "+----------------------+--------+----------------------+----------------------+\n",
      "| 那是什么破酒店，...  | 0      | ['那', '是', '什...  | ['那', '是', '什...  |\n",
      "| 非常一般的酒店，...  | 0      | ['非常', '一般',...  | ['非常', '一般',...  |\n",
      "| 入住了传说中的04...  | 1      | ['入住', '了', '...  | ['入住', '了', '...  |\n",
      "| 餐厅很差,菜的种类... | 0      | ['餐厅', '很差',...  | ['餐厅', '很差',...  |\n",
      "| 酒店位于福州最热...  | 1      | ['酒店', '位于',...  | ['酒店', '位于',...  |\n",
      "| 上月入住，将近50...  | 0      | ['上', '月', '入...  | ['上', '月', '入...  |\n",
      "| 老掉牙的五星酒店...  | 0      | ['老掉牙', '的',...  | ['老掉牙', '的',...  |\n",
      "| 酒店位置在很市中...  | 1      | ['酒店', '位置',...  | ['酒店', '位置',...  |\n",
      "| 除了停车不是很方...  | 1      | ['除了', '停车',...  | ['除了', '停车',...  |\n",
      "| 原先看了别人的评...  | 1      | ['原先', '看', '...  | ['原先', '看', '...  |\n",
      "| 宜宾还有很多好的...  | 0      | ['宜宾', '还有',...  | ['宜宾', '还有',...  |\n",
      "| 不错,变成常驻京酒... | 1      | ['不错', ',', '变... | ['不错', ',', '变... |\n",
      "| 最近变成了四星，...  | 0      | ['最近', '变成',...  | ['最近', '变成',...  |\n",
      "| 一分钟一块钱的网...  | 0      | ['一分钟', '一块...  | ['一分钟', '一块...  |\n",
      "| 四个字，“糟糕透...   | 0      | ['四个', '字', '...  | ['四个', '字', '...  |\n",
      "| 我住的是大床房，...  | 1      | ['我', '住', '的...  | ['我', '住', '的...  |\n",
      "| 几个月之后再次入...  | 1      | ['几个', '月', '...  | ['几个', '月', '...  |\n",
      "| 这次和老公入住这...  | 0      | ['这次', '和', '...  | ['这次', '和', '...  |\n",
      "| 忘了说，前台服务...  | 0      | ['忘', '了', '说...  | ['忘', '了', '说...  |\n",
      "| 房间价格奇贵，上...  | 0      | ['房间', '价格',...  | ['房间', '价格',...  |\n",
      "| 不错，离大连港很...  | 1      | ['不错', '，', '...  | ['不错', '，', '...  |\n",
      "| 前台的服务很差，...  | 0      | ['前台', '的', '...  | ['前台', '的', '...  |\n",
      "| 酒店邻近山脚，空...  | 1      | ['酒店', '邻近',...  | ['酒店', '邻近',...  |\n",
      "| 这是北京能见到的...  | 0      | ['这是', '北京',...  | ['这是', '北京',...  |\n",
      "| 房间比中州皇冠之...  | 1      | ['房间', '比', '...  | ['房间', '比', '...  |\n",
      "| 这家YMCA非常好，...  | 1      | ['这家', 'YMCA',...  | ['这家', 'YMCA',...  |\n",
      "| 便宜，物有所值。...  | 1      | ['便宜', '，', '...  | ['便宜', '，', '...  |\n",
      "| 只能说，还可以住...  | 0      | ['只能', '说', '...  | ['只能', '说', '...  |\n",
      "| 优点:靠近商业中心... | 0      | ['优点', ':', '靠... | ['优点', ':', '靠... |\n",
      "| 服务不是很专业，...  | 1      | ['服务', '不是',...  | ['服务', '不是',...  |\n",
      "| 酒店的位置不错，...  | 1      | ['酒店', '的', '...  | ['酒店', '的', '...  |\n",
      "| ...                  | ...    | ...                  | ...                  |\n",
      "+----------------------+--------+----------------------+----------------------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda ins:get_tokenized(ins['raw_words']), new_field_name='words', is_input=True)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| raw_words         | target | chars             | words             | seq_len |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| 那是什么破酒店... | 0      | ['那', '是', ...  | ['那', '是', ...  | 330     |\n",
      "| 非常一般的酒店... | 0      | ['非常', '一般... | ['非常', '一般... | 36      |\n",
      "| 入住了传说中的... | 1      | ['入住', '了'...  | ['入住', '了'...  | 69      |\n",
      "| 餐厅很差,菜的...  | 0      | ['餐厅', '很差... | ['餐厅', '很差... | 35      |\n",
      "| 酒店位于福州最... | 1      | ['酒店', '位于... | ['酒店', '位于... | 81      |\n",
      "| 上月入住，将近... | 0      | ['上', '月', ...  | ['上', '月', ...  | 98      |\n",
      "| 老掉牙的五星酒... | 0      | ['老掉牙', '的... | ['老掉牙', '的... | 42      |\n",
      "| 酒店位置在很市... | 1      | ['酒店', '位置... | ['酒店', '位置... | 25      |\n",
      "| 除了停车不是很... | 1      | ['除了', '停车... | ['除了', '停车... | 30      |\n",
      "| 原先看了别人的... | 1      | ['原先', '看'...  | ['原先', '看'...  | 235     |\n",
      "| 宜宾还有很多好... | 0      | ['宜宾', '还有... | ['宜宾', '还有... | 22      |\n",
      "| 不错,变成常驻...  | 1      | ['不错', ',',...  | ['不错', ',',...  | 13      |\n",
      "| 最近变成了四星... | 0      | ['最近', '变成... | ['最近', '变成... | 25      |\n",
      "| 一分钟一块钱的... | 0      | ['一分钟', '一... | ['一分钟', '一... | 127     |\n",
      "| 四个字，“糟糕...  | 0      | ['四个', '字'...  | ['四个', '字'...  | 125     |\n",
      "| 我住的是大床房... | 1      | ['我', '住', ...  | ['我', '住', ...  | 63      |\n",
      "| 几个月之后再次... | 1      | ['几个', '月'...  | ['几个', '月'...  | 37      |\n",
      "| 这次和老公入住... | 0      | ['这次', '和'...  | ['这次', '和'...  | 220     |\n",
      "| 忘了说，前台服... | 0      | ['忘', '了', ...  | ['忘', '了', ...  | 67      |\n",
      "| 房间价格奇贵，... | 0      | ['房间', '价格... | ['房间', '价格... | 68      |\n",
      "| 不错，离大连港... | 1      | ['不错', '，'...  | ['不错', '，'...  | 160     |\n",
      "| 前台的服务很差... | 0      | ['前台', '的'...  | ['前台', '的'...  | 44      |\n",
      "| 酒店邻近山脚，... | 1      | ['酒店', '邻近... | ['酒店', '邻近... | 108     |\n",
      "| 这是北京能见到... | 0      | ['这是', '北京... | ['这是', '北京... | 188     |\n",
      "| 房间比中州皇冠... | 1      | ['房间', '比'...  | ['房间', '比'...  | 47      |\n",
      "| 这家YMCA非常好... | 1      | ['这家', 'YMC...  | ['这家', 'YMC...  | 40      |\n",
      "| 便宜，物有所值... | 1      | ['便宜', '，'...  | ['便宜', '，'...  | 12      |\n",
      "| 只能说，还可以... | 0      | ['只能', '说'...  | ['只能', '说'...  | 63      |\n",
      "| 优点:靠近商业...  | 0      | ['优点', ':',...  | ['优点', ':',...  | 79      |\n",
      "| 服务不是很专业... | 1      | ['服务', '不是... | ['服务', '不是... | 28      |\n",
      "| 酒店的位置不错... | 1      | ['酒店', '的'...  | ['酒店', '的'...  | 53      |\n",
      "| ...               | ...    | ...               | ...               | ...     |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda ins: len(ins['words']) ,new_field_name='seq_len', is_input=True)\n",
    "\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| raw_words         | target | chars             | words             | seq_len |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| 那是什么破酒店... | 0      | ['那', '是', ...  | ['那', '是', ...  | 330     |\n",
      "| 非常一般的酒店... | 0      | ['非常', '一般... | ['非常', '一般... | 36      |\n",
      "| 入住了传说中的... | 1      | ['入住', '了'...  | ['入住', '了'...  | 69      |\n",
      "| 餐厅很差,菜的...  | 0      | ['餐厅', '很差... | ['餐厅', '很差... | 35      |\n",
      "| 酒店位于福州最... | 1      | ['酒店', '位于... | ['酒店', '位于... | 81      |\n",
      "| 上月入住，将近... | 0      | ['上', '月', ...  | ['上', '月', ...  | 98      |\n",
      "| 老掉牙的五星酒... | 0      | ['老掉牙', '的... | ['老掉牙', '的... | 42      |\n",
      "| 酒店位置在很市... | 1      | ['酒店', '位置... | ['酒店', '位置... | 25      |\n",
      "| 除了停车不是很... | 1      | ['除了', '停车... | ['除了', '停车... | 30      |\n",
      "| 原先看了别人的... | 1      | ['原先', '看'...  | ['原先', '看'...  | 235     |\n",
      "| 宜宾还有很多好... | 0      | ['宜宾', '还有... | ['宜宾', '还有... | 22      |\n",
      "| 不错,变成常驻...  | 1      | ['不错', ',',...  | ['不错', ',',...  | 13      |\n",
      "| 最近变成了四星... | 0      | ['最近', '变成... | ['最近', '变成... | 25      |\n",
      "| 一分钟一块钱的... | 0      | ['一分钟', '一... | ['一分钟', '一... | 127     |\n",
      "| 四个字，“糟糕...  | 0      | ['四个', '字'...  | ['四个', '字'...  | 125     |\n",
      "| 我住的是大床房... | 1      | ['我', '住', ...  | ['我', '住', ...  | 63      |\n",
      "| 几个月之后再次... | 1      | ['几个', '月'...  | ['几个', '月'...  | 37      |\n",
      "| 这次和老公入住... | 0      | ['这次', '和'...  | ['这次', '和'...  | 220     |\n",
      "| 忘了说，前台服... | 0      | ['忘', '了', ...  | ['忘', '了', ...  | 67      |\n",
      "| 房间价格奇贵，... | 0      | ['房间', '价格... | ['房间', '价格... | 68      |\n",
      "| 不错，离大连港... | 1      | ['不错', '，'...  | ['不错', '，'...  | 160     |\n",
      "| 前台的服务很差... | 0      | ['前台', '的'...  | ['前台', '的'...  | 44      |\n",
      "| 酒店邻近山脚，... | 1      | ['酒店', '邻近... | ['酒店', '邻近... | 108     |\n",
      "| 这是北京能见到... | 0      | ['这是', '北京... | ['这是', '北京... | 188     |\n",
      "| 房间比中州皇冠... | 1      | ['房间', '比'...  | ['房间', '比'...  | 47      |\n",
      "| 这家YMCA非常好... | 1      | ['这家', 'YMC...  | ['这家', 'YMC...  | 40      |\n",
      "| 便宜，物有所值... | 1      | ['便宜', '，'...  | ['便宜', '，'...  | 12      |\n",
      "| 只能说，还可以... | 0      | ['只能', '说'...  | ['只能', '说'...  | 63      |\n",
      "| 优点:靠近商业...  | 0      | ['优点', ':',...  | ['优点', ':',...  | 79      |\n",
      "| 服务不是很专业... | 1      | ['服务', '不是... | ['服务', '不是... | 28      |\n",
      "| 酒店的位置不错... | 1      | ['酒店', '的'...  | ['酒店', '的'...  | 53      |\n",
      "| ...               | ...    | ...               | ...               | ...     |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda x: int(x['target']), new_field_name='target', is_target=True)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------------------+---------------------------+---------+\n",
      "| raw_words                 | words                     | seq_len |\n",
      "+---------------------------+---------------------------+---------+\n",
      "| 入住大连富丽华东楼，新... | ['入住', '大连', '富丽... | 17      |\n",
      "| 6月初入住的，据说后楼...  | ['6', '月初', '入住',...  | 161     |\n",
      "| 三人间，房间很宽敞，卫... | ['三人间', '，', '房间... | 76      |\n",
      "| 很好．非常干净．是新装... | ['很', '好', '．', '非... | 29      |\n",
      "| 4月份回青岛顺便带朋友...  | ['4', '月份', '回', '...  | 66      |\n",
      "| 我是看了携程的评价才去... | ['我', '是', '看', '了... | 138     |\n",
      "| 早上八点多到酒店,预先...  | ['早上', '八点', '多到... | 15      |\n",
      "| 该酒店根本就达不到星级... | ['该', '酒店', '根本'...  | 80      |\n",
      "| 订的单人间，结果没宽带... | ['订', '的', '单人间'...  | 52      |\n",
      "| 差的不能在差了，厕所垃... | ['差', '的', '不能', ...  | 40      |\n",
      "| 潮州迎宾馆确实装修了，... | ['潮州', '迎宾馆', '确... | 49      |\n",
      "| 我订了二个房间。酒店隔... | ['我订', '了', '二个'...  | 199     |\n",
      "| 出于不信邪的原因，入住... | ['出于', '不信邪', '的... | 79      |\n",
      "| 我订的是430元的大床间...  | ['我订', '的', '是', ...  | 285     |\n",
      "| 前台服務員的禮貌有待改... | ['前台', '服務員', '的... | 97      |\n",
      "| 入住的5天，天天洗桑拿...  | ['入住', '的', '5', '...  | 308     |\n",
      "| 房间装修有点旧了，隔音... | ['房间', '装修', '有点... | 31      |\n",
      "| 不能说没有一点点问题，... | ['不能', '说', '没有'...  | 38      |\n",
      "| 总体说来,应该是二星级...  | ['总体', '说来', ',',...  | 62      |\n",
      "| 这个酒店总体不错,安排...  | ['这个', '酒店', '总体... | 22      |\n",
      "| 隔音很差，交通尚可，设... | ['隔音', '很差', '，'...  | 17      |\n",
      "| 不错呀,房间够大,因为带... | ['不错呀', ',', '房间...  | 53      |\n",
      "| 在沙田,地理位置也不算...  | ['在', '沙田', ',', '...  | 59      |\n",
      "| 房间保养不错.前台服务...  | ['房间', '保养', '不错... | 21      |\n",
      "| 大床房还不错，其他的没... | ['大床', '房', '还', ...  | 20      |\n",
      "| 地段好，到九龙坐酒店门... | ['地段', '好', '，', ...  | 69      |\n",
      "| 差到一定程度了 不如一...  | ['差到', '一定', '程度... | 15      |\n",
      "| 服务太差了,堂堂的五星...  | ['服务', '太差', '了'...  | 38      |\n",
      "| 同样的价格相差一个月，... | ['同样', '的', '价格'...  | 16      |\n",
      "| 酒店确实很差，而且价格... | ['酒店', '确实', '很差... | 24      |\n",
      "| 告知所有的朋友，这个酒... | ['告知', '所有', '的'...  | 66      |\n",
      "| ...                       | ...                       | ...     |\n",
      "+---------------------------+---------------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "#testset.apply(lambda ins: list(chain.from_iterable(get_tokenized(ins['raw_words']))), new_field_name='words', is_input=True)\n",
    "\n",
    "testset.apply(lambda ins: get_tokenized(ins['raw_words']), new_field_name='words', is_input=True)\n",
    "\n",
    "testset.apply(lambda ins: len(ins['words']) ,new_field_name='seq_len',is_input=True)\n",
    "print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "###\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "#将DataSet按照ratio的比例拆分，返回两个DataSet\n",
    "\n",
    "#ratio (float) -- 0<ratio<1, 返回的第一个DataSet拥有 (1-ratio) 这么多数据，第二个DataSet拥有`ratio`这么多数据\n",
    "\n",
    "train_data, dev_data = dataset.split(0.1, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| raw_words         | target | chars             | words             | seq_len |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n",
      "| 这是我见过最垃... | 0      | ['这', '是', ...  | ['这', '是', ...  | 136     |\n",
      "| 设施很陈旧，卫... | 0      | ['设施', '很'...  | ['设施', '很'...  | 12      |\n",
      "| 携程的员工最好... | 0      | ['携程', '的'...  | ['携程', '的'...  | 158     |\n",
      "| 很好，这次入住... | 1      | ['很', '好', ...  | ['很', '好', ...  | 93      |\n",
      "| 因为春节参加朋... | 0      | ['因为', '春节... | ['因为', '春节... | 74      |\n",
      "| 前台服务较差，... | 1      | ['前台', '服务... | ['前台', '服务... | 32      |\n",
      "| 比较老的酒店，... | 1      | ['比较', '老'...  | ['比较', '老'...  | 20      |\n",
      "| 我住的是大床间... | 0      | ['我', '住', ...  | ['我', '住', ...  | 40      |\n",
      "| 我是在海洋公园... | 1      | ['我', '是', ...  | ['我', '是', ...  | 355     |\n",
      "| 回来2天了，很...  | 1      | ['回来', '2',...  | ['回来', '2',...  | 724     |\n",
      "| 性价比低，感觉... | 0      | ['性价比', '低... | ['性价比', '低... | 47      |\n",
      "| 住宿环境还是不... | 1      | ['住宿', '环境... | ['住宿', '环境... | 105     |\n",
      "| 2007年9月份曾...  | 1      | ['2007', '年'...  | ['2007', '年'...  | 53      |\n",
      "| 房间很大，交通... | 1      | ['房间', '很大... | ['房间', '很大... | 44      |\n",
      "| 酒店设施很差，... | 0      | ['酒店设施', ...  | ['酒店设施', ...  | 52      |\n",
      "| 地理位置很好，... | 1      | ['地理位置', ...  | ['地理位置', ...  | 36      |\n",
      "| 房价要四百多，... | 0      | ['房价', '要'...  | ['房价', '要'...  | 37      |\n",
      "| 酒店设施起码有... | 0      | ['酒店设施', ...  | ['酒店设施', ...  | 145     |\n",
      "| 服务态度很好，... | 1      | ['服务态度', ...  | ['服务态度', ...  | 17      |\n",
      "| 房间太旧了， ...  | 1      | ['房间', '太旧... | ['房间', '太旧... | 34      |\n",
      "| 其他没什么，就... | 0      | ['其他', '没什... | ['其他', '没什... | 52      |\n",
      "| 比较老的五星级... | 1      | ['比较', '老'...  | ['比较', '老'...  | 15      |\n",
      "| 在大连四星级酒... | 1      | ['在', '大连'...  | ['在', '大连'...  | 95      |\n",
      "| 有可能是世界上... | 0      | ['有', '可能'...  | ['有', '可能'...  | 39      |\n",
      "| 非常好。国内最... | 1      | ['非常', '好'...  | ['非常', '好'...  | 14      |\n",
      "| 很不好的酒店,...  | 0      | ['很', '不好'...  | ['很', '不好'...  | 236     |\n",
      "| 总的来说还过的... | 1      | ['总的来说', ...  | ['总的来说', ...  | 116     |\n",
      "| 尽管是四星的宾... | 0      | ['尽管', '是'...  | ['尽管', '是'...  | 44      |\n",
      "| 1。我住的是靠...  | 0      | ['1', '。', '...  | ['1', '。', '...  | 116     |\n",
      "| 如果你住习惯了... | 0      | ['如果', '你'...  | ['如果', '你'...  | 100     |\n",
      "| 说是五星，房间... | 1      | ['说', '是', ...  | ['说', '是', ...  | 23      |\n",
      "| ...               | ...    | ...               | ...               | ...     |\n",
      "+-------------------+--------+-------------------+-------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "print(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1080 120 800\n"
     ]
    }
   ],
   "source": [
    "print(len(train_data),len(dev_data),len(testset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vocabulary(['那', '是', '什么', '破', '酒店']...)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = Vocabulary(min_freq=2).from_dataset(dataset, field_name='words')\n",
    "\n",
    "vocab.index_dataset(train_data, dev_data, testset, field_name='words', new_field_name='words')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 709 out of 5160 words in the pre-training embedding.\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding,StackEmbedding\n",
    "\n",
    "fastnlp_embed = StaticEmbedding(vocab, model_dir_or_name='cn-char-fastnlp-100d',min_freq=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNNText(\n",
      "  (embed): Embedding(\n",
      "    (embed): StaticEmbedding(\n",
      "      (dropout_layer): Dropout(p=0)\n",
      "      (embedding): Embedding(5160, 100, padding_idx=0)\n",
      "    )\n",
      "    (dropout): Dropout(p=0.0)\n",
      "  )\n",
      "  (conv_pool): ConvMaxpool(\n",
      "    (convs): ModuleList(\n",
      "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
      "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1)\n",
      "  (fc): Linear(in_features=120, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.models import CNNText\n",
    "\n",
    "model_CNN = CNNText(fastnlp_embed, num_classes=2,dropout=0.1)\n",
    "\n",
    "print(model_CNN)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input fields after batch(if batch size is 2):\n",
      "\tchars: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,) \n",
      "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 136]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "training epochs started 2020-06-18-10-42-05-446458\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=340.0), HTML(value='')), layout=Layout(di…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.6 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 1/10. Step:34/340: \n",
      "\r",
      "AccuracyMetric: acc=0.833333\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.63 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 2/10. Step:68/340: \n",
      "\r",
      "AccuracyMetric: acc=0.9\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.59 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 3/10. Step:102/340: \n",
      "\r",
      "AccuracyMetric: acc=0.858333\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.61 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 4/10. Step:136/340: \n",
      "\r",
      "AccuracyMetric: acc=0.925\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.72 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 5/10. Step:170/340: \n",
      "\r",
      "AccuracyMetric: acc=0.908333\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.59 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 6/10. Step:204/340: \n",
      "\r",
      "AccuracyMetric: acc=0.916667\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.68 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 7/10. Step:238/340: \n",
      "\r",
      "AccuracyMetric: acc=0.908333\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.76 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 8/10. Step:272/340: \n",
      "\r",
      "AccuracyMetric: acc=0.908333\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.58 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 9/10. Step:306/340: \n",
      "\r",
      "AccuracyMetric: acc=0.916667\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.63 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 10/10. Step:340/340: \n",
      "\r",
      "AccuracyMetric: acc=0.908333\n",
      "\n",
      "\r",
      "Reloaded the best model.\n",
      "\n",
      "In Epoch:4/Step:136, got best dev performance:\n",
      "AccuracyMetric: acc=0.925\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'best_epoch': 4,\n",
       " 'best_eval': {'AccuracyMetric': {'acc': 0.925}},\n",
       " 'best_step': 136,\n",
       " 'seconds': 124.97}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric,BCELoss\n",
    "\n",
    "trainer_CNN = Trainer(model=model_CNN, train_data=train_data, dev_data=dev_data,loss=CrossEntropyLoss(), metrics=AccuracyMetric())\n",
    "\n",
    "trainer_CNN.train()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "demo=[]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "#批量进行数据预测\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "\n",
    "def batch_predict(model,data):\n",
    "    submission=pd.DataFrame(columns=['Prediction'])\n",
    "#    submission = pd.DataFrame(columns=['ID','Prediction'])\n",
    "    for i in range(len(data)):\n",
    "    #for i in range(5):\n",
    "#         print(data.words[i])\n",
    "        tensor = torch.tensor(data.words[i])\n",
    "        pred = model.predict(tensor.view(1,-1))\n",
    "#         print(pred)\n",
    "        prob = pred['pred'].numpy()[0]\n",
    "#         print(\"pred:%.2f\"%(prob))\n",
    "#         print('='*50)\n",
    "#         print(type(prob))\n",
    "        s2 = pd.Series([float(prob)], index=['Prediction'])\n",
    "        demo.append(prob)\n",
    "#         print(s2)\n",
    "        submission = submission.append(s2, ignore_index=True)\n",
    "        submission['Prediction'] = submission.Prediction .astype(int)\n",
    "#         submission['']\n",
    "#         submission['Prediction'] = submission.Prediction.astype(float) \n",
    "    return submission\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "#开始进行预测，并将结果保存到提交格式文件中，提交平台\n",
    "\n",
    "# summission_path = r'data\\Comments9120'\n",
    "\n",
    "submission = batch_predict(model_CNN,testset)\n",
    "submission.to_csv('fastnlpDemo.csv',encoding='utf-8')\n",
    "# submission.to_csv(summission_path+'\\submission-CNN-20200229-words.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
